#!/usr/bin/python
# https://stackoverflow.com/questions/6908143/should-i-put-shebang-in-python-scripts-and-what-form-should-it-take


import math
import numpy as np
import os
import itertools
import sys
import time

# https://stackoverflow.com/questions/9819733/scipy-special-import-issue
# https://docs.scipy.org/doc/scipy/reference/api.html#guidelines-for-importing-functions-from-scipy
#import scipy
from scipy import special


###### General parameters and Mnova data import
# Upper bound on the number of candidate schedules that are evaluated; when
# more combinations exist, a random subset of this size is drawn instead.
samp_limit = int(1e4)


# dataset must contain file "multiplets1.txt" at the expno level
# with the centres of each region (currently in points), and the peak amplitude

# output will be a file "best_sampling.txt" in the same place as the multiplets
# file, with the appropriate NUS schedule
# this must be dealt with by the AU program in question to put it into the right place

if len(sys.argv) == 1:
    # Built-in test data so the script can be run stand-alone.
    print('No arguments supplied - using test data...')
    shift = [5646., 5440.3, 5308., 5246., 5188., 5089., 4909., 4590., 4250.127, 3557., 3531., 3337.5]
    a = [2., 8., 1., 1., 3., 1., 2., 1., 0., 1., 1., 1.]
    npo = 8184
    fwhm = 0.0006027947292202448
    chunk = 88
    factor = 0.2
    dataset_folder = os.getcwd()

else:
    # arguments
    # sys.argv[0] = name of (this) script
    # sys.argv[1] = source of multiplet list (also destination folder for NUS schedule)
    # sys.argv[2] = nominal linewidth
    # sys.argv[3] = total number of points (after reconstruction) - 2*L31*1TD
    # sys.argv[4] = chunk size
    # sys.argv[5] = NUS factor #do we have to come back and set it in the main dataset? don't think so

    # Fail early with a readable message instead of an IndexError further down.
    if len(sys.argv) < 6:
        sys.exit('Usage: %s <multiplets file> <linewidth> <total points> '
                 '<chunk size> <NUS factor>' % sys.argv[0])

    # with cleanup because of how TopSpin "likes" to handle file separators
    multiplets_file = os.path.abspath(sys.argv[1])
    dataset_folder = os.path.dirname(multiplets_file)

    # ndmin=2 keeps the result two-dimensional even when the file holds a
    # single multiplet; otherwise the column indexing below would fail.
    multiplets = np.loadtxt(multiplets_file, ndmin=2)

    # chemical shifts of multiplets in numbers of points
    shift = multiplets[:, 0]
    print(shift)

    # relative amplitudes of multiplets
    # this should come out of the intrng file from the preparation dataset
    a = multiplets[:, 1]
    print(a)

    # Hz * pi * AQ / spectral width
    fwhm = float(sys.argv[2])

    # points to be acquired
    # chunks*2*(swh/swh1 + GRPDLY)
    npo = int(sys.argv[3])

    # chunk length in points
    # 2*(swh/swh1 + GRPDLY)
    chunk = int(sys.argv[4])

    # NUS factor
    # NusAMOUNT/100
    factor = float(sys.argv[5])


###### Simulating singlet spectrum
# Sample index axis 0 .. npo-1 (as floats).
n = np.arange(npo, dtype=float)

# Accumulate one damped complex exponential per multiplet: amplitude a[k],
# frequency shift[k] (in points of the npo-point spectrum), decay rate fwhm.
fid = np.zeros(npo, dtype=complex)
for idx, amplitude in enumerate(a):
    fid = fid + amplitude*np.exp((2*math.pi*1j*shift[idx]/npo - fwhm)*n)

# Reference spectrum, scaled so that its tallest point has magnitude 1.
# The index of that point is reused later to scale the NUS spectra.
s = np.fft.fft(fid)
magnitude = np.abs(s)
peak_for_normalization = np.argmax(magnitude)
s = s/magnitude[peak_for_normalization]


###### Generating samplings
# Chunks are numbered 1 .. full_number_of_chunks.  Every candidate schedule
# is one row of Chunks_measured and always starts with chunk 1; the remaining
# number_of_measured_chunks-1 entries are picked from chunks 2 .. last.
full_number_of_chunks = npo // chunk  # floor division on Python 2 and 3 alike
number_of_measured_chunks = int(math.floor(full_number_of_chunks*factor))
print('Full number of chunks - %d , number of measured chunks - %d'
      % (full_number_of_chunks, number_of_measured_chunks))
if number_of_measured_chunks < 1:
    raise ValueError('NUS factor %s selects no chunks out of %d'
                     % (factor, full_number_of_chunks))

# exact=True returns an arbitrary-precision integer, so very large counts
# cannot overflow to a float infinity before the comparison below.
number_of_samplings = special.comb(full_number_of_chunks-1, number_of_measured_chunks-1, exact=True)
print('Number of all samplings:  %d' % number_of_samplings)

print('\nGenerating sampling schedules...')
start = time.time()
np.random.seed(1)  # fixed seed: the random subset is reproducible
if number_of_samplings > samp_limit:
    print('Using randomly selected samplings!')
    Chunks_measured = np.ones((samp_limit+2, number_of_measured_chunks))
    # Candidates include the last chunk, matching the exhaustive branch and
    # the combination count above.
    candidate_chunks = np.arange(2, full_number_of_chunks + 1)
    for i in range(samp_limit):
        shuffled = np.random.permutation(candidate_chunks)
        Chunks_measured[i, 1:] = np.sort(shuffled[:number_of_measured_chunks - 1])
    # Two deterministic reference schedules are appended after the random ones.
    Chunks_measured[samp_limit, :] = np.arange(1, number_of_measured_chunks + 1)  # truncation
    step = math.floor(1/factor)
    Chunks_measured[samp_limit+1, :] = Chunks_measured[samp_limit, :]*step - (step - 1)  # uniform distribution
else:
    print('Using all samplings!')
    combos = np.array(list(itertools.combinations(range(2, full_number_of_chunks + 1),
                                                  number_of_measured_chunks-1)))
    # Prepend the always-measured first chunk to every combination.
    Chunks_measured = np.hstack((np.ones((np.shape(combos)[0], 1)), combos))
Chunks_measured = Chunks_measured.astype(int)
print('\nSamplings:  %s %s' % (Chunks_measured, np.shape(Chunks_measured)))

###### Comparing samplings
# For every candidate schedule, build the FID with only the measured chunks
# kept (everything else zero), transform it, and score how far the resulting
# spectrum deviates from the fully sampled reference spectrum s.
number_of_candidates = np.shape(Chunks_measured)[0]
s_nus = np.zeros((number_of_candidates, len(s)), dtype=complex)
max_artefact = np.zeros((number_of_candidates,))
L1 = np.zeros((number_of_candidates,))
L2 = np.zeros((number_of_candidates,))

# Optional precomputed spectra in the current working directory.  They are
# only accepted when their shape matches the present problem.
# NOTE: a cache made from different multiplets but with identical dimensions
# cannot be detected by this check.
cached_spectra = None
if os.path.isfile('s_nus.npy'):
    cached_spectra = np.load('s_nus.npy')
    if np.shape(cached_spectra) != np.shape(s_nus):
        print('Ignoring s_nus.npy - shape does not match current samplings')
        cached_spectra = None

if cached_spectra is not None:
    s_nus = cached_spectra
else:
    for j in range(number_of_candidates):
        fid_nus = np.zeros(np.shape(fid), dtype=complex)
        # Chunk numbers are 1-based; chunk c covers points (c-1)*chunk .. c*chunk-1.
        for chunk_number in Chunks_measured[j, :]:
            first = (chunk_number - 1)*chunk
            fid_nus[first:first + chunk] = fid[first:first + chunk]
        spectrum = np.fft.fft(fid_nus)
        # Scale by the magnitude at the reference spectrum's tallest point.
        s_nus[j, :] = spectrum/np.abs(spectrum[peak_for_normalization])

# Deviation metrics, computed the same way whichever branch supplied s_nus.
for j in range(number_of_candidates):
    deviation = np.abs(s_nus[j, :] - s)
    max_artefact[j] = np.max(deviation)
    L1[j] = np.sum(deviation)
    L2[j] = np.sum(deviation**2)
end = time.time()
print('done!')
print('Time:  %s s' % (end - start))


# The winning schedule is the one whose largest deviation from the reference
# spectrum is smallest.
best_index = np.argmin(max_artefact)
selected_spectrum = s_nus[best_index, :]
# NOTE(review): freshly computed rows of s_nus are already scaled to unit
# magnitude at peak_for_normalization, so this factor is then ~1 - confirm
# whether an un-normalised FID was intended here.
selected_fid = np.fft.ifft(selected_spectrum)*np.abs(selected_spectrum[peak_for_normalization])
# Convert the 1-based chunk numbers to 0-based indices for the output file.
selected_sampling = Chunks_measured[best_index, :] - 1
print('Best sampling:  %s' % (selected_sampling,))

# can be stored in the dataset of interest - the AU program can handle the rest
nuslist_destination = os.path.join(dataset_folder, 'best_sampling.txt')
print(nuslist_destination)  # some ugly mixed filesep handling
np.savetxt(nuslist_destination, selected_sampling, fmt='%d')